# *_*coding:utf-8 *_*
# @Author : Reggie
# @Time : 2021/11/26 18:30
import logging
import os
import platform
import sys
import time
from concurrent.futures._base import as_completed
from concurrent.futures.thread import ThreadPoolExecutor
from pathlib import Path

import paramiko

# Refuse to load on anything other than Windows.
# NOTE(review): presumably because the workflow below addresses the license
# share through a Windows UNC path -- confirm before relaxing this guard.
_running_on_windows = platform.system().lower() == "windows"
if not _running_on_windows:
    raise ValueError("Only supported on Windows")

# Locations of the serial-number and license-key files on each remote node.
RemoteSnFilePath = "/etc/sysconfig/uraid/sn"
RemoteKeyFilePath = "/etc/sysconfig/uraid/key"

# Any non-empty DEBUG environment value raises the root log level to DEBUG.
env_debug = bool(os.environ.get("DEBUG"))
LOG_LEVEL = "DEBUG" if env_debug else "INFO"

# Any non-empty "nolog" environment value skips logging configuration entirely.
nolog = bool(os.environ.get("nolog"))
if not nolog:
    formatter = logging.Formatter(
        "%(asctime)s -- %(levelname)s -- %(message)s (%(funcName)s in %(filename)s %(lineno)d)"
    )
    stream_handle = logging.StreamHandler()
    stream_handle.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.addHandler(stream_handle)
    root_logger.setLevel(LOG_LEVEL)

    # Keep paramiko's own logger quiet below WARNING.
    logger = logging.getLogger("paramiko")
    logger.setLevel("WARNING")


class SSHConnection(object):
    """Password-authenticated SSH/SFTP session over a single ``paramiko.Transport``.

    Usable as a context manager: the transport is opened on ``__enter__`` and
    closed on ``__exit__``.  ``connect()`` / ``close()`` may also be called
    explicitly.

    Args:
        host: Remote host name or IP address.
        port: Remote SSH port.
        username: Login user name.
        pwd: Login password.
        use_callback: When true, SFTP transfers print a progress bar through
            :meth:`call_back`.
    """

    def __init__(self, host, port=22, username='root', pwd='user@dev', use_callback=False):
        self.host = host
        self.port = port
        self.username = username
        self.pwd = pwd
        self.use_callback = use_callback
        self.__transport = None

    def __enter__(self):
        """Open the connection; raise ``RuntimeError`` if one is already held."""
        if self.__transport is not None:
            raise RuntimeError("Already connected")
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection; exceptions from the ``with`` body propagate."""
        self.close()

    def _require_transport(self):
        """Return the current transport or raise ``RuntimeError`` if not connected."""
        if self.__transport is None:
            raise RuntimeError("Not connected")
        return self.__transport

    def connect(self):
        """Open and authenticate the transport; a no-op when already active."""
        if self.__transport is not None:
            if self.__transport.active:
                return
            # A dead transport is being replaced: release it first.
            self.__transport.close()
            self.__transport = None
        transport = paramiko.Transport((self.host, self.port))
        try:
            transport.connect(username=self.username, password=self.pwd)
        except Exception:
            # Do not leak the half-open transport when connecting fails.
            transport.close()
            raise
        self.__transport = transport

    def close(self):
        """Close the transport.  Safe to call when not connected or repeatedly."""
        transport, self.__transport = self.__transport, None
        if transport is not None:
            transport.close()

    def translate_byte(self, B):
        """Format a byte count as a human-readable string using 1024-based units.

        Values below 1024 are rendered as ``"<float> byte(s)"``; larger values
        use two decimals with a KB / MB / GB / TB suffix (TB for everything
        from 1024**4 upwards).
        """
        B = float(B)
        KB = 1024.0
        MB = KB ** 2
        GB = KB ** 3
        TB = KB ** 4
        if B < KB:
            return '{} {}'.format(B, 'byte' if B == 1 else 'bytes')
        if B < MB:
            return '{:.2f} KB'.format(B / KB)
        if B < GB:
            return '{:.2f} MB'.format(B / MB)
        if B < TB:
            return '{:.2f} GB'.format(B / GB)
        return '{:.2f} TB'.format(B / TB)

    def call_back(self, transferred, toBeTransferred, suffix=''):
        """Write one progress-bar line to stdout.

        Args:
            transferred: Number of bytes transferred so far.
            toBeTransferred: Total number of bytes; ``0`` is shown as complete.
            suffix: Optional text appended after the percentage.
        """
        bar_len = 100
        total = float(toBeTransferred)
        if total > 0:
            filled_len = int(round(bar_len * transferred / total))
            percents = round(100.0 * transferred / total, 1)
        else:
            # Nothing to transfer: report completion instead of dividing by zero.
            filled_len = bar_len
            percents = 100.0
        # Parenthesised so the colour codes wrap the whole filled section once
        # rather than being repeated for every '=' character.
        bar = '\033[32;1m%s\033[0m' % ('=' * filled_len) + '-' * (bar_len - filled_len)
        sys.stdout.write('[%s] %s%s %s\r\n' % (bar, '\033[32;1m%s\033[0m' % percents, '%', suffix))
        sys.stdout.flush()

    def sftp_down_file(self, server_path, local_path):
        """Download ``server_path`` from the remote host to ``local_path``.

        Returns whatever ``SFTPClient.get`` returns.  Raises ``RuntimeError``
        when not connected.
        """
        sftp = paramiko.SFTPClient.from_transport(self._require_transport())
        try:
            return sftp.get(server_path, local_path,
                            callback=self.call_back if self.use_callback else None)
        finally:
            # Release the SFTP channel; the transport itself stays open.
            sftp.close()

    def sftp_upload_file(self, local_path, server_path):
        """Upload ``local_path`` to ``server_path`` on the remote host.

        Returns whatever ``SFTPClient.put`` returns.  Raises ``RuntimeError``
        when not connected.
        """
        sftp = paramiko.SFTPClient.from_transport(self._require_transport())
        try:
            return sftp.put(local_path, server_path,
                            callback=self.call_back if self.use_callback else None)
        finally:
            # Release the SFTP channel; the transport itself stays open.
            sftp.close()

    def cmd(self, command, encode="utf-8"):
        """Run ``command`` on the remote host and wait for its standard output.

        Args:
            command: Command line to execute remotely.
            encode: Encoding used to decode the captured standard output.

        Returns:
            A tuple ``(output_text, stdout, stderr)`` where ``output_text`` is
            the decoded standard output and ``stdout`` / ``stderr`` are the
            paramiko channel files.  No exit status is included in the tuple;
            it can be read with ``stdout.channel.recv_exit_status()``.

        Raises:
            RuntimeError: When not connected.
        """
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        # Reuse the already-authenticated transport instead of reconnecting.
        ssh._transport = self._require_transport()
        # Execute the command.
        stdin, stdout, stderr = ssh.exec_command(command)
        # Collect the command output (reads until the stream is closed).
        result = stdout.read()
        return result.decode(encode), stdout, stderr


def gen_node_uus_sn(host, port=22, key_wait_timeout=None):
    """Fetch a node's serial number, wait for its license key, and install it.

    Steps, all performed while one SSH connection to ``host`` is held open:

    1. Download the remote serial-number file to the license share as
       ``sn<host without dots>``.
    2. Poll the share until a file whose name starts with ``key`` and ends
       with ``<host without dots>`` appears.
    3. Upload that key file to the node and run ``ucli svc restart v-meta``.

    Args:
        host: IP address (or host name) of the node.
        port: SSH port of the node.
        key_wait_timeout: Maximum number of seconds to wait for the key file.
            ``None`` (the default) waits indefinitely.

    Raises:
        TimeoutError: If ``key_wait_timeout`` elapses before a key file shows up.
    """
    license_share = Path(r"\\172.18.50.50\lic")
    host_tag = host.replace(".", "")
    remote_sn_file_path = Path(RemoteSnFilePath)
    remote_key_file_path = Path(RemoteKeyFilePath)

    with SSHConnection(host, port) as client:
        logging.info(f"{'*' * 50}")
        logging.info(f"start gen {host} sn key")

        # Download the serial number straight onto the license share.
        local_sn_file_path = license_share.joinpath(f"sn{host_tag}")
        logging.info(f"local sn file path: {local_sn_file_path}")
        # Logged in POSIX form because this path lives on the remote node.
        logging.info(f"remote sn file path: {remote_sn_file_path.as_posix()}")
        client.sftp_down_file(remote_sn_file_path.as_posix(), local_sn_file_path.as_posix())
        logging.info(f"sn: {local_sn_file_path.read_text()}")

        # Poll the share until the matching key file appears.
        deadline = None if key_wait_timeout is None else time.monotonic() + key_wait_timeout
        key_file = None
        while key_file is None:
            for smb_file in license_share.iterdir():
                name = smb_file.name
                if name.startswith("key") and name.endswith(host_tag):
                    key_file = smb_file
                    break
            else:
                if deadline is not None and time.monotonic() >= deadline:
                    raise TimeoutError(
                        f"no key file for {host} appeared in {license_share} "
                        f"within {key_wait_timeout} seconds"
                    )
                time.sleep(0.2)
        logging.info(f"find key file: {key_file.as_posix()}")
        logging.info(f"key: {key_file.read_text()}")

        # Upload the license key to the node.
        client.sftp_upload_file(key_file.as_posix(), remote_key_file_path.as_posix())
        logging.info(f"upload key file: {remote_key_file_path.as_posix()}")

        # cmd() returns (output_text, stdout, stderr); the exit status has to be
        # read from the channel behind stdout.
        _, stdout, _ = client.cmd("ucli svc restart v-meta")
        rc = stdout.channel.recv_exit_status()
        if rc == 0:
            logging.info("Restart v-meta success")
        else:
            logging.error("Restart v-meta failed")


def main(ips):
    """Run ``gen_node_uus_sn`` for every host, up to 10 hosts in parallel.

    Args:
        ips: Either a whitespace-separated string of hosts, or a list/tuple
            of hosts.

    Raises:
        ValueError: If ``ips`` is neither a string nor a list/tuple.
        Exception: After every host has been processed, the first exception
            raised by any host's task is re-raised.
    """
    if isinstance(ips, str):
        ip_list = ips.split()
    elif isinstance(ips, (list, tuple)):
        ip_list = list(ips)
    else:
        raise ValueError(f"{ips} type error. Must be list, tuple or str")

    first_error = None
    # The context manager shuts the pool down once all tasks have finished.
    with ThreadPoolExecutor(max_workers=10) as executor:
        tasks = {executor.submit(gen_node_uus_sn, ip): ip for ip in ip_list}
        for future in as_completed(tasks):
            ip = tasks[future]
            try:
                future.result()
            except Exception as exc:
                # Keep reporting the remaining hosts; fail once at the end.
                logging.error("in main: %s failed", ip, exc_info=True)
                if first_error is None:
                    first_error = exc
            else:
                print("in main: {} success".format(ip))

    if first_error is not None:
        raise first_error


if __name__ == '__main__':
    # Hosts may be given as command-line arguments; with none, the built-in
    # default set below is used.
    # Node sets that were previously hard-coded here (kept for reference):
    #   172.16.120.134 172.16.120.135 172.16.120.136
    #   10.200.2.100 10.200.2.101
    #   10.200.2.60 10.200.2.100 10.200.2.101
    ips = sys.argv[1:] or "10.200.102.14 10.200.102.19 10.200.102.25"
    main(ips)
